/*
 * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved.
 * Licensed under the MIT License.
 */
package ai.onnxruntime;

import java.nio.ByteBuffer;
import java.util.EnumSet;

/** Configuration options for compiling ONNX models. */
public final class OrtModelCompilationOptions implements AutoCloseable {
  /** Flags representing options when compiling a model. */
  public enum OrtCompileApiFlags implements OrtFlags {
    /** Default. Do not enable any additional compilation options. */
    NONE(0),

    /**
     * Force compilation to return an error (ORT_FAIL) if no nodes were compiled. Otherwise, a model
     * with basic optimizations (ORT_ENABLE_BASIC) is still generated by default.
     */
    ERROR_IF_NO_NODES_COMPILED(1),

    /**
     * Force compilation to return an error (ORT_FAIL) if a file with the same filename as the
     * output model exists. Otherwise, compilation will automatically overwrite the output file if
     * it exists.
     */
    ERROR_IF_OUTPUT_FILE_EXISTS(1 << 1);

    /** The native value of the enum. */
    public final int value;

    OrtCompileApiFlags(int value) {
      this.value = value;
    }

    @Override
    public int getValue() {
      return value;
    }
  }

  private final long nativeHandle;
  private boolean closed = false;

  // Used to ensure the byte buffer doesn't get GC'd before the model is compiled.
  private ByteBuffer buffer;

  OrtModelCompilationOptions(long nativeHandle) {
    this.nativeHandle = nativeHandle;
  }

  /**
   * Creates a model compilation options from an existing SessionOptions.
   *
   * <p>An OrtModelCompilationOptions object contains the settings used to generate a compiled ONNX
   * model. The OrtSessionOptions object has the execution providers with which the model will be
   * compiled.
   *
   * @param env The OrtEnvironment.
   * @param sessionOptions The session options to use.
   * @return A constructed model compilation options instance.
   * @throws OrtException If the construction failed.
   */
  public static OrtModelCompilationOptions createFromSessionOptions(
      OrtEnvironment env, OrtSession.SessionOptions sessionOptions) throws OrtException {
    long handle =
        createFromSessionOptions(
            OnnxRuntime.ortApiHandle,
            OnnxRuntime.ortCompileApiHandle,
            env.getNativeHandle(),
            sessionOptions.getNativeHandle());
    return new OrtModelCompilationOptions(handle);
  }

  /**
   * Checks if the OrtModelCompilationOptions is closed, if so throws {@link IllegalStateException}.
   */
  private void checkClosed() {
    if (closed) {
      throw new IllegalStateException("Trying to use a closed OrtModelCompilationOptions.");
    }
  }

  @Override
  public void close() {
    if (!closed) {
      close(OnnxRuntime.ortCompileApiHandle, nativeHandle);
      closed = true;
    } else {
      throw new IllegalStateException("Trying to close a closed OrtModelCompilationOptions.");
    }
  }

  /**
   * Sets the file path to the input ONNX model.
   *
   * <p>The input model's location must be set either to a path on disk with this method, or by
   * supplying an in-memory reference with {@link #setInputModelFromBuffer}.
   *
   * @param inputModelPath The path to the model on disk.
   * @throws OrtException If the set failed.
   */
  public void setInputModelPath(String inputModelPath) throws OrtException {
    checkClosed();
    setInputModelPath(
        OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, inputModelPath);
  }

  /**
   * Uses the supplied buffer as the input ONNX model.
   *
   * <p>The input model's location must be set either to an in-memory reference with this method, or
   * by supplying a path on disk with {@link #setInputModelPath(String)}.
   *
   * <p>If the {@link ByteBuffer} is not direct it is copied into a direct buffer. In either case
   * this object holds a reference to the buffer to prevent it from being GC'd.
   *
   * @param inputModelBuffer The buffer.
   * @throws OrtException If the buffer could not be set.
   */
  public void setInputModelFromBuffer(ByteBuffer inputModelBuffer) throws OrtException {
    checkClosed();
    if (!inputModelBuffer.isDirect()) {
      // if it's not a direct buffer, copy it.
      buffer = ByteBuffer.allocateDirect(inputModelBuffer.remaining());
      int tmpPos = inputModelBuffer.position();
      buffer.put(inputModelBuffer);
      buffer.rewind();
      inputModelBuffer.position(tmpPos);
    } else {
      buffer = inputModelBuffer;
    }
    int bufferPos = buffer.position();
    int bufferRemaining = buffer.remaining();
    setInputModelFromBuffer(
        OnnxRuntime.ortApiHandle,
        OnnxRuntime.ortCompileApiHandle,
        nativeHandle,
        buffer,
        bufferPos,
        bufferRemaining);
  }

  /**
   * Sets the file path for the output compiled ONNX model.
   *
   * <p>If this is unset it will append `_ctx` to the file name, e.g., my_model.onnx becomes
   * my_model_ctx.onnx.
   *
   * @param outputModelPath The output model path.
   * @throws OrtException If the path could not be set.
   */
  public void setOutputModelPath(String outputModelPath) throws OrtException {
    checkClosed();
    setOutputModelPath(
        OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, outputModelPath);
  }

  /**
   * Optionally sets the file that stores initializers for the compiled ONNX model. If unset then
   * initializers are stored inside the model.
   *
   * <p>Only initializers for nodes that were not compiled are stored in the external initializers
   * file. Compiled nodes contain their initializer data within the `ep_cache_context` attribute of
   * EPContext nodes.
   *
   * @see OrtModelCompilationOptions#setEpContextEmbedMode
   * @param outputExternalInitializersPath Path to the file.
   * @param sizeThreshold Initializers larger than this threshold are stored in the file.
   * @throws OrtException If the path could not be set.
   */
  public void setOutputExternalInitializersPath(
      String outputExternalInitializersPath, long sizeThreshold) throws OrtException {
    checkClosed();
    // check positive
    setOutputExternalInitializersPath(
        OnnxRuntime.ortApiHandle,
        OnnxRuntime.ortCompileApiHandle,
        nativeHandle,
        outputExternalInitializersPath,
        sizeThreshold);
  }

  /**
   * Enables or disables the embedding of EPContext binary data into the ep_cache_context attribute
   * of EPContext nodes.
   *
   * <p>Defaults to false. When enabled, the `ep_cache_context` attribute of EPContext nodes will
   * store the context binary data, which may include weights for compiled subgraphs. When disabled,
   * the `ep_cache_context` attribute of EPContext nodes will contain the path to the file
   * containing the context binary data. The path is set by the execution provider creating the
   * EPContext node.
   *
   * <p>For more details see the <a
   * href="https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html">EPContext design
   * document.</a>
   *
   * @param embedEpContext True to embed EPContext binary data into the EPContext node's
   *     ep_cache_context attribute.
   * @throws OrtException If the set operation failed.
   */
  public void setEpContextEmbedMode(boolean embedEpContext) throws OrtException {
    checkClosed();
    setEpContextEmbedMode(
        OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, embedEpContext);
  }

  /**
   * Sets the specified compilation flags.
   *
   * @param flags The compilation flags.
   * @throws OrtException If the set operation failed.
   */
  public void setCompilationFlags(EnumSet<OrtCompileApiFlags> flags) throws OrtException {
    checkClosed();
    setCompilationFlags(
        OnnxRuntime.ortApiHandle,
        OnnxRuntime.ortCompileApiHandle,
        nativeHandle,
        OrtFlags.aggregateToInt(flags));
  }

  /**
   * Compiles the ONNX model with the configuration described by this instance of
   * OrtModelCompilationOptions.
   *
   * @throws OrtException If the compilation failed.
   */
  public void compileModel() throws OrtException {
    checkClosed();
    // Safe as the environment must exist to create one of these objects.
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    compileModel(
        OnnxRuntime.ortApiHandle,
        OnnxRuntime.ortCompileApiHandle,
        env.getNativeHandle(),
        nativeHandle);
  }

  private static native long createFromSessionOptions(
      long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException;

  private static native void close(long compileApiHandle, long nativeHandle);

  private static native void setInputModelPath(
      long apiHandle, long compileApiHandle, long nativeHandle, String inputModelPath)
      throws OrtException;

  private static native void setInputModelFromBuffer(
      long apiHandle,
      long compileApiHandle,
      long nativeHandle,
      ByteBuffer inputBuffer,
      long bufferPos,
      long bufferRemaining)
      throws OrtException;

  private static native void setOutputModelPath(
      long apiHandle, long compileApiHandle, long nativeHandle, String outputModelPath)
      throws OrtException;

  private static native void setOutputExternalInitializersPath(
      long apiHandle,
      long compileApiHandle,
      long nativeHandle,
      String externalInitializersPath,
      long sizeThreshold)
      throws OrtException;

  private static native void setEpContextEmbedMode(
      long apiHandle, long compileApiHandle, long nativeHandle, boolean embedEpContext)
      throws OrtException;

  private static native void setCompilationFlags(
      long apiHandle, long compileApiHandle, long nativeHandle, int flags) throws OrtException;

  private static native void compileModel(
      long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException;
}
